
Conversation

@muqing-li commented Jan 15, 2026

Thanks for your contribution; we appreciate it a lot. The following instructions will make your pull request healthier and help you get feedback more easily. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers.

PR Type

  • Feature
  • Bugfix
  • Docs
  • CI/CD
  • Refactor
  • Perf
  • Dependency
  • Test-Cases
  • Other

Related Issue
Fixes #(issue ID) / Relates to #(issue ID)

🔍 Motivation

Please describe the motivation of this PR and the goal you want to achieve through this PR.

The AISBench evaluation tool currently cannot directly load model weights produced by MindFormers training for offline evaluation. A unified wrapper/adaptation layer for MindFormers/MindSpore models is therefore needed on the AISBench side to make such evaluation possible.

Goal: add a model-loading script for the MindFormers framework and the corresponding AISBench config file, so that AISBench can evaluate MindFormers model weights with a command like:

ais_bench --models mf_model --datasets <eval task>

📝 Modification

Please briefly describe what modification is made in this PR.

  1. Add the config file ais_bench/benchmark/configs/models/mf_models/mindformers_model.py

     In the AISBench config file, set type: MindFormerModel to select a MindFormers model. The differences from the HF offline-model config are as follows:

      type=MindFormerModel,
      abbr='mindformer-model',
      yaml_cfg_path = 'THUDM/your.yaml',  # path to the MindFormers yaml file; this value is just an example

     Note: yaml_cfg_path is a new option that points to the MindFormers YAML config.

  2. Add the model wrapper file ais_bench/benchmark/models/mindformers_model.py

     It mainly covers model building (load_tokenizer, load_model, load_checkpoint) and the generate interface, which wraps MindFormers' text_generate call plus some post-processing.

     Using AISBench's factory mechanism, the MindFormers wrapper class is registered via @MODELS.register_module(); a structural sketch follows this list.

  3. Modify the AISBench runner logic that builds the subprocess command for inference tasks, so that a MindFormers model is launched with msrun instead of torchrun:

     if self.abbr == 'mindformer-model':
         command = (
             f"msrun "
             f"--worker_num={self.num_gpus} "
             f"--local_worker_num={self.node_gpus} "
             f"--master_port={port} "
             f"--log_dir='output/msrun_log' "
             f"--join=True "
             f'{script_path} {cfg_path}'
         )
     else:
         command = (f'torchrun --master_port={port} '
                    f'--nproc_per_node {self.num_procs} '
                    f'{script_path} {cfg_path}')
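
For reference, here is a minimal structural sketch of the wrapper described in item 2 (an illustration only: the import paths and helper calls are assumptions, not the actual AISBench/MindFormers APIs):

    # Sketch only: the import paths and helper names below are assumptions.
    from ais_bench.benchmark.registry import MODELS          # hypothetical registry path
    from ais_bench.benchmark.models.base import BaseModel    # hypothetical base-class path

    @MODELS.register_module()
    class MindFormerModel(BaseModel):
        def __init__(self, path, yaml_cfg_path, batch_size=1, **kwargs):
            self.config = ...     # parse the MindFormers YAML at yaml_cfg_path
            self.tokenizer = ...  # load_tokenizer
            self.model = ...      # load_model: build the network, set_train(False)
            ...                   # load_checkpoint: load/transform the distributed weights

        def generate(self, inputs, max_out_len, **kwargs):
            outputs = ...         # wrap MindFormers' text_generate call
            return ...            # decode and strip stop strings (post-processing)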

📐 Associated Test Results

Please provide links to the related test results, such as CI pipelines, test reports, etc.

AISBench demo execution results

Black-box verification: HuggingFace weights were converted to MindFormers weights; loading the converted weights through MindFormers yields the same scores as loading the original weights through the HuggingFace flow.
Evaluating the same weights with HuggingFaceModel and with MindFormerModel keeps the score difference within 1%:

Qwen3-0.6B C-Eval scores (single device)

Benchmark             MindFormerModel  HuggingFaceModel
ceval-stem            44.27            44.58
ceval-social-science  55.83            56.34
ceval-humanities      50.72            48.74
ceval-other           50.30            50.49
ceval-hard            39.59            40.37
ceval                 49.13            48.97
ceval-weighted        48.96            48.74

Qwen3-30B-A3B MMLU scores (multi-device)

Benchmark            MindFormerModel  HuggingFaceModel
mmlu-humanities      79.89            80.16
mmlu-stem            78.47            77.51
mmlu-social-science  86.15            86.32
mmlu-other           79.40            79.63
mmlu                 80.62            80.86
mmlu-weighted        79.19            78.99

⚠️ BC-breaking (Optional)

Does the modification introduce changes that break the backward compatibility of the downstream repositories? If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.

⚠️ Performance degradation (Optional)

If the modification introduces performance degradation, please describe the impact of the performance degradation and the expected performance improvement.

🌟 Use cases (Optional)

If this PR introduces a new feature, it is better to list some use cases here and update the documentation.

✅ Checklist

Before PR:

  • Pre-commit or other linting tools are used to fix potential lint issues.
  • Bug fixes are fully covered by unit tests; the case that caused the bug is added to the unit tests.
  • The modification is covered by complete unit tests. If not, please add more unit tests to ensure correctness.
  • All relevant documentation (API docs, docstrings, example tutorials) has been updated to reflect these changes.

After PR:

  • If the modification has potential influence on downstream or other related projects, this PR should be tested with those projects.
  • The CLA has been signed by all committers in this PR.

👥 Collaboration Info

  • Suggested Reviewers: @xxx
  • Relevant Module Owners: @xxx
  • Other Collaboration Notes:

🌟 Useful CI Commands

Command             Introduction
/gemini review      Performs a code review of the current pull request in its current state by Gemini.
/gemini summary     Provides a summary of the current pull request in its current state by Gemini.
/gemini help        Displays a list of available Gemini commands.
/readthedocs build  Triggers a Read the Docs build of the documentation for the current pull request in its current state.

@gemini-code-assist

Summary of Changes

Hello @muqing-li, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the AISBench evaluation tool so that it can directly load and evaluate model weights trained with the MindFormers framework. By introducing a unified wrapper for MindFormers models and adapting to their distributed launch mechanism, AISBench can now run offline evaluation of MindFormers models from a simple command line, extending the model ecosystem it supports.

Highlights

  • New MindFormers model config: adds ais_bench/benchmark/configs/models/mf_models/mf_model.py, which selects a MindFormers model via type: MindFormerModel and introduces the yaml_cfg_file parameter for the path to the MindFormers YAML config.
  • MindFormers model wrapper: introduces ais_bench/benchmark/models/local_models/mindformers_model.py with the MindFormerModel class, which wraps MindFormers model loading (tokenizer, model, checkpoint) and generation (text_generate) and registers itself with AISBench's model factory.
  • Runner adaptation for MindFormers models: modifies the runner logic in ais_bench/benchmark/tasks/openicl_infer.py so that inference tasks for the mindformer-model type are launched with msrun instead of torchrun, supporting MindFormers' distributed execution.


@gemini-code-assist bot left a comment

Code Review

Thanks for the contribution. This PR adds MindFormers model support to AISBench, which is a very valuable feature. The overall code structure is clear, but there are some key implementation issues to fix. The main problems are in the MindFormerModel class, where several core evaluation methods (such as get_ppl) incorrectly mix PyTorch and MindSpore APIs, which will cause runtime errors. There are also a missing import, a bare except statement, and some hardcoded configuration. Please resolve these before merging to ensure correctness and maintainability.

        model = Model(self.model)
        input_ids = Tensor(np.ones((batch_size, max_seq_len), dtype=np.int32))
        infer_data = self.model.prepare_inputs_for_predict_layout(input_ids)
        transform_and_load_checkpoint(config, model, self.model, infer_data, do_eval=True)


critical

The function transform_and_load_checkpoint is called here, but it is neither defined in this file nor imported, which will raise a NameError at runtime. Please import it from mindformers.checkpoint.checkpoint.

Suggested change
transform_and_load_checkpoint(config, model, self.model, infer_data, do_eval=True)
from mindformers.checkpoint.checkpoint import transform_and_load_checkpoint
transform_and_load_checkpoint(config, model, self.model, infer_data, do_eval=True)

Comment on lines 371 to 392
        outputs, inputs = self.get_logits(inputs)
        shift_logits = outputs[..., :-1, :].contiguous().float()

        shift_labels = inputs['tokens']['input_ids'][..., 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss(
            reduction='none', ignore_index=self.tokenizer.pad_token_id)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                        shift_labels.view(-1)).view(shift_labels.size())

        if mask_length is not None:
            mask = torch.zeros_like(shift_labels)  # [batch,seqlen]
            for i in range(len(mask)):
                for j in range(mask_length[i] - 1, len(mask[i])):
                    mask[i][j] = 1
            loss = loss * mask

        lens = (inputs['tokens']['input_ids'] !=
                self.tokenizer.pad_token_id).sum(-1).cpu().numpy()
        if mask_length is not None:
            lens -= np.array(mask_length)
        ce_loss = loss.float().sum(-1).cpu().detach().numpy() / lens
        return ce_loss


critical

This method, together with the get_loglikelihood and get_mink_percent methods, has a serious problem: when handling MindSpore model outputs, they incorrectly use PyTorch APIs (e.g., contiguous().float() and torch.nn.CrossEntropyLoss). MindSpore Tensor objects are not compatible with PyTorch Tensors, so this will fail at runtime. These functions need to be rewritten entirely with MindSpore APIs to work correctly.
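
For illustration, a MindSpore-native version of the loss computation might look like the sketch below (the helper name is mine; it assumes logits and input_ids are mindspore.Tensor and uses mindspore.ops.cross_entropy in place of torch.nn.CrossEntropyLoss):

    # Sketch of a MindSpore-only per-sample CE loss; not the final implementation.
    import numpy as np
    import mindspore as ms
    from mindspore import ops

    def per_sample_ce_loss(logits: ms.Tensor, input_ids: ms.Tensor, pad_token_id: int) -> np.ndarray:
        shift_logits = logits[..., :-1, :]   # [batch, seq-1, vocab]
        shift_labels = input_ids[..., 1:]    # [batch, seq-1]
        loss = ops.cross_entropy(
            shift_logits.reshape(-1, shift_logits.shape[-1]),
            shift_labels.reshape(-1),
            ignore_index=pad_token_id,
            reduction='none',
        ).reshape(shift_labels.shape)
        # average over each sample's non-padding tokens
        lens = (input_ids != pad_token_id).astype(ms.int32).sum(-1).asnumpy()
        return loss.sum(-1).asnumpy() / lens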


        try:
            generation_config = GenerationConfig.from_pretrained(path)
        except:


high

A bare except: catches every exception, including process-exit signals like SystemExit and KeyboardInterrupt, which makes the program hard to debug and to interrupt cleanly. Catch a more specific exception, such as Exception, to avoid surprising behavior.

Suggested change
except:
except Exception:

abbr='mindformer-model',
path='THUDM/chatglm-6b', # path to model dir, current value is just a example
checkpoint = 'THUDM/your_checkpoint', # path to checkpoint file, current value is just a example
yaml_cfg_file = 'THUDM//your.yaml',


medium

The yaml_cfg_file path uses a double slash (//). While this may work on some operating systems, it is not standard and can cause path-resolution problems elsewhere. Use a single slash for correctness and consistency.

Suggested change
yaml_cfg_file = 'THUDM//your.yaml',
yaml_cfg_file = 'THUDM/your.yaml',

Comment on lines 18 to 20
from mindformers.utils.load_checkpoint_utils import load_checkpoint_with_safetensors, get_load_path_after_hf_convert
from mindformers.checkpoint.checkpoint import load_checkpoint
from mindformers.checkpoint.utils import compile_model


medium

load_checkpoint_with_safetensors, load_checkpoint, and compile_model are imported in this file but never used. Remove these unused imports to keep the code clean and avoid confusion.

        self.logger.info("..........Network Built Successfully..........")
        self.model.set_train(False)
        config.load_checkpoint = get_load_path_after_hf_convert(config, self.model)
        self.logger.info("load checkpoint path : ",config.load_checkpoint)


medium

The logging call self.logger.info is given multiple positional arguments. Best practice is to pass a single formatted string; this improves readability and matches standard logging conventions.

Suggested change
self.logger.info("load checkpoint path : ",config.load_checkpoint)
self.logger.info(f"load checkpoint path : {config.load_checkpoint}")

Comment on lines +144 to +147
    if attr == "local" and abbr == "mindformer-model" :
        return MODELS.build(model_cfg, batch_size = batch_size)
    else :
        return MODELS.build(model_cfg)


medium

This function adds special handling for mindformer-model to pass batch_size. It solves the immediate problem, but it hurts the maintainability of a generic build function: adding an if/else branch for every new model type makes the code ever more complex. Consider a more generic mechanism for model-specific initialization parameters instead of hardcoding model names.

@muqing-li (Author)

/gemini review

@muqing-li reopened this Jan 15, 2026
@muqing-li (Author)

/gemini help

@gemini-code-assist bot left a comment

Code Review

This pull request adds support for evaluating MindFormers models within AISBench. The changes include a new MindFormerModel wrapper, a corresponding configuration file, and modifications to the task runner to use msrun for distributed execution. While the core feature of running MindFormer models seems well-implemented, I've found several critical issues in the utility methods for calculating PPL and log-likelihood, which appear to be incorrectly mixing PyTorch and MindSpore APIs, and will lead to runtime errors. Additionally, there are bugs in the multi-node command generation for msrun. I've provided detailed comments and suggestions to address these issues.

Comment on lines 353 to 392
    def _get_ppl(self,
                 inputs: List[str],
                 mask_length: Optional[List[int]] = None) -> List[float]:
        """Get perplexity scores given a list of inputs.

        Args:
            inputs (List[str]): A list of strings.
            mask_length (Optional[List[int]]): A list of mask lengths. If
                provided, the perplexity scores will be calculated with the
                first mask_length[i] tokens masked out. It's okay to skip
                its implementation if advanced features in PPLInfernecer is
                not needed.

        Returns:
            List[float]: A list of perplexity scores.
        """

        outputs, inputs = self.get_logits(inputs)
        shift_logits = outputs[..., :-1, :].contiguous().float()

        shift_labels = inputs['tokens']['input_ids'][..., 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss(
            reduction='none', ignore_index=self.tokenizer.pad_token_id)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                        shift_labels.view(-1)).view(shift_labels.size())

        if mask_length is not None:
            mask = torch.zeros_like(shift_labels)  # [batch,seqlen]
            for i in range(len(mask)):
                for j in range(mask_length[i] - 1, len(mask[i])):
                    mask[i][j] = 1
            loss = loss * mask

        lens = (inputs['tokens']['input_ids'] !=
                self.tokenizer.pad_token_id).sum(-1).cpu().numpy()
        if mask_length is not None:
            lens -= np.array(mask_length)
        ce_loss = loss.float().sum(-1).cpu().detach().numpy() / lens
        return ce_loss


critical

The methods get_ppl, get_loglikelihood, and get_mink_percent (and their internal _ counterparts) are implemented using PyTorch APIs (torch.nn.CrossEntropyLoss, .contiguous(), .float(), etc.). However, the tensors they operate on (outputs from get_logits and inputs['tokens']['input_ids']) are MindSpore tensors. This mismatch will cause a runtime crash.

To fix this, you need to convert the MindSpore tensors to PyTorch tensors before using them with PyTorch functions, for example, by using .asnumpy() and torch.from_numpy().

Additionally, get_loglikelihood has further issues like using PyTorch-specific return_tensors='pt' and model.device, which will also fail with a MindSpore model.

    def _get_ppl(self,
                 inputs: List[str],
                 mask_length: Optional[List[int]] = None) -> List[float]:
        """Get perplexity scores given a list of inputs.

        Args:
            inputs (List[str]): A list of strings.
            mask_length (Optional[List[int]]): A list of mask lengths. If
                provided, the perplexity scores will be calculated with the
                first mask_length[i] tokens masked out. It's okay to skip
                its implementation if advanced features in PPLInfernecer is
                not needed.

        Returns:
            List[float]: A list of perplexity scores.
        """

        outputs, inputs_data = self.get_logits(inputs)
        # Convert MindSpore tensors to PyTorch tensors for loss calculation
        shift_logits = torch.from_numpy(outputs.asnumpy())[..., :-1, :].contiguous().float()

        input_ids_torch = torch.from_numpy(inputs_data['tokens']['input_ids'].asnumpy())
        shift_labels = input_ids_torch[..., 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss(
            reduction='none', ignore_index=self.tokenizer.pad_token_id)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                        shift_labels.view(-1)).view(shift_labels.size())

        if mask_length is not None:
            mask = torch.zeros_like(shift_labels)  # [batch,seqlen]
            for i in range(len(mask)):
                for j in range(mask_length[i] - 1, len(mask[i])):
                    mask[i][j] = 1
            loss = loss * mask

        lens = (input_ids_torch !=
                self.tokenizer.pad_token_id).sum(-1).cpu().numpy()
        if mask_length is not None:
            lens -= np.array(mask_length)
        ce_loss = loss.float().sum(-1).cpu().detach().numpy() / lens
        return ce_loss

Comment on lines 82 to 95
            if self.abbr == "mindformer-model" :
                command = (
                    f"msrun "
                    f"--worker_num={self.num_procs} "
                    f"--local_worker_num={self.local_worker_num} "
                    f"--master_port={port} "
                    f"--master_addr={self.master_addr} "
                    f"--node_rank={self.node_rank} "
                    f"--log_dir='output/msrun_log' "
                    f"--join=True "
                    f"--cluster_time_out=7200 "
                    f'{script_path} {cfg_path}'
                )


high

The msrun command for multi-node execution (nnodes > 1) has incorrect parameters:

  1. Incorrect worker_num: It's set to self.num_procs, which is the number of processes per node. It should be the total number of workers across all nodes (self.num_procs * self.nnodes).
  2. Incorrect local_worker_num: It's set to self.local_worker_num, which defaults to 0 if not provided in the config. This is incorrect. It should be the number of workers on the current node, which is self.num_procs.
            if self.abbr == "mindformer-model" :
                total_workers = self.num_procs * self.nnodes
                local_workers = self.local_worker_num if self.local_worker_num > 0 else self.num_procs
                command = (
                    f"msrun "
                    f"--worker_num={total_workers} "
                    f"--local_worker_num={local_workers} "
                    f"--master_port={port} "
                    f"--master_addr={self.master_addr} "
                    f"--node_rank={self.node_rank} "
                    f"--log_dir='output/msrun_log' "
                    f"--join=True "
                    f"--cluster_time_out=7200 "
                    f'{script_path} {cfg_path}'
                )

        self.batch_size = batch_size
        self.pad_token_id = pad_token_id
        self.pretrained_model_path = path
        assert mode in ['none', 'mid']


medium

Using assert for validating input parameters is not recommended for production code, as assertions can be disabled with the -O flag, which would bypass this check. It's more robust to raise a ValueError for invalid inputs to ensure the validation is always performed.

Suggested change
assert mode in ['none', 'mid']
if mode not in ['none', 'mid']:
    raise ValueError(f"mode must be 'none' or 'mid', but got {mode}")

        self.logger.info("..........Network Built Successfully..........")
        self.model.set_train(False)
        config.load_checkpoint = get_load_path_after_hf_convert(config, self.model)
        self.logger.info("load checkpoint path : ",config.load_checkpoint)


medium

The logger call self.logger.info("load checkpoint path : ", config.load_checkpoint) passes multiple arguments. A standard Python logger will print this as a tuple, e.g., ('load checkpoint path : ', 'your/path'). To format the log message correctly, you should use an f-string or other string formatting methods.

Suggested change
self.logger.info("load checkpoint path : ",config.load_checkpoint)
self.logger.info(f"load checkpoint path : {config.load_checkpoint}")

Comment on lines 173 to 174
        except ValueError :
            raise ValueError('Failed to load MindFormers model, please check configuration')


medium

The except ValueError: clause catches the exception but doesn't include the original exception when raising a new one. This discards the original stack trace and error context, making debugging more difficult. It's better to use raise ... from e to preserve the exception chain.

Suggested change
except ValueError :
    raise ValueError('Failed to load MindFormers model, please check configuration')
except ValueError as e:
    raise ValueError('Failed to load MindFormers model, please check configuration') from e

Comment on lines 65 to 95
            if self.abbr == 'mindformer-model':
                command = (
                    f"msrun "
                    f"--worker_num={self.num_gpus} "
                    f"--local_worker_num={self.num_gpus} "
                    f"--master_port={port} "
                    f"--log_dir='output/msrun_log' "
                    f"--join=True "
                    f"--cluster_time_out=7200 "
                    f'{script_path} {cfg_path}'
                )
            else :
                command = (f'torchrun --master_port={port} '
                           f'--nproc_per_node {self.num_procs} '
                           f'{script_path} {cfg_path}')
        elif self.nnodes > 1:
            port = 12345
            command = (f'torchrun --master_port={port} '
            if self.abbr == "mindformer-model" :
                command = (
                    f"msrun "
                    f"--worker_num={self.num_procs} "
                    f"--local_worker_num={self.local_worker_num} "
                    f"--master_port={port} "
                    f"--master_addr={self.master_addr} "
                    f"--node_rank={self.node_rank} "
                    f"--log_dir='output/msrun_log' "
                    f"--join=True "
                    f"--cluster_time_out=7200 "
                    f'{script_path} {cfg_path}'
                )


medium

The logic for constructing the msrun command is duplicated between the single-node multi-GPU case and the multi-node case. This reduces maintainability, as changes need to be applied in two places. Consider refactoring this into a helper function to build the msrun command string.
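
One possible shape for such a helper (a sketch only; attribute names mirror the diff above and fold in the worker-count fixes suggested in this review):

    # Sketch: build the msrun command in one place for both launch modes.
    def _build_msrun_command(self, script_path, cfg_path, port):
        parts = [
            "msrun",
            f"--worker_num={self.num_procs * self.nnodes}",  # total workers across all nodes
            f"--local_worker_num={self.num_procs}",          # workers on this node
            f"--master_port={port}",
            "--log_dir='output/msrun_log'",
            "--join=True",
            "--cluster_time_out=7200",
        ]
        if self.nnodes > 1:  # multi-node runs also need the coordinator address and node rank
            parts += [f"--master_addr={self.master_addr}", f"--node_rank={self.node_rank}"]
        return " ".join(parts + [script_path, cfg_path])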

Comment on lines +144 to +147
    if attr == "local" and abbr == "mindformer-model" :
        return MODELS.build(model_cfg, batch_size = batch_size)
    else :
        return MODELS.build(model_cfg)


medium

The code adds special handling for mindformer-model by checking its attr and abbr to pass batch_size to its constructor. This creates a tight coupling and makes the build process harder to maintain. If other models need similar treatment in the future, this if/else block will grow.

A more scalable approach would be to avoid this special casing. For instance, you could consider not popping batch_size from the config and letting models that don't need it ignore it via **kwargs. If popping it is necessary for other models, a more generic mechanism (e.g., a flag in the model config) would be preferable to hardcoding model names.


@muqing-li force-pushed the mf_model branch 2 times, most recently from 9377abd to 3e27b24 on January 15, 2026 at 12:13
@muqing-li (Author)

/gemini review

@gemini-code-assist bot left a comment

Code Review

This PR successfully adds MindFormers model support to AISBench, including a new model wrapper, a config file, and adaptation to the msrun launcher. The code is well structured and achieves its goal.

I have a few improvement suggestions, mainly in the following areas:

  • Robustness: in mindformers_model.py, check yaml_cfg_file (which may be None) before use, and replace assert with a more appropriate ValueError for invalid input.
  • Maintainability: remove unused imports, avoid hardcoded values (such as the fastchat template), and refactor the duplicated msrun command-building logic in openicl_infer.py.
  • Clarity: add a comment explaining the special-case logic in build.py, and fix the misleading comment in the config file.
  • Performance: optimize the stop-string post-processing in the generate method to avoid repeatedly rebuilding lists.

These changes will improve robustness, maintainability, and performance. Overall, a great contribution.

        if mode not in ['none', 'mid']:
            raise ValueError(f"mode must be 'none' or 'mid', but got {mode}")
        self.mode = mode
        self.config = MindFormerConfig(yaml_cfg_file)


high

yaml_cfg_fileNone 时,MindFormerConfig(yaml_cfg_file) 可能会失败。__init__ 函数签名允许 yaml_cfg_fileNone,因此在调用 MindFormerConfig 之前应该添加一个检查,以确保 yaml_cfg_file 不是 None,或者在 yaml_cfg_fileNone 时进行适当的错误处理。

例如:

if not yaml_cfg_file:
    raise ValueError('`yaml_cfg_file` is required for MindFormerModel')
self.config = MindFormerConfig(yaml_cfg_file)

        seed = None,
        repetition_penalty = 1.03,
    ),
    run_cfg = dict(num_gpus=1, num_procs=1), # multi-device/multi-node settings; tasks are launched with torchrun


medium

The comment here mentions torchrun, but MindFormer models are actually launched with msrun. This can mislead users; update the comment to msrun, or to a more generic description of the launcher actually used.

Suggested change
run_cfg = dict(num_gpus=1, num_procs=1), # multi-device/multi-node settings; tasks are launched with torchrun
run_cfg = dict(num_gpus=1, num_procs=1), # multi-device/multi-node settings; tasks are launched with msrun

from typing import Dict, List, Optional, Union

import numpy as np
import torch


medium

The torch module is imported but never used in this file. Remove the unused import to keep the code clean and avoid an unnecessary dependency.

Comment on lines 108 to 111
        kwargs = DEFAULT_TOKENIZER_KWARGS
        kwargs.update(tokenizer_kwargs)


medium

DEFAULT_TOKENIZER_KWARGS is a mutable dict. Assigning kwargs = DEFAULT_TOKENIZER_KWARGS and then calling kwargs.update() mutates DEFAULT_TOKENIZER_KWARGS itself, which can cause unexpected side effects on later calls. Create a copy when assigning to avoid this.
Also, defining DEFAULT_TOKENIZER_KWARGS as a class-level constant would be clearer.

Suggested change
kwargs = DEFAULT_TOKENIZER_KWARGS
kwargs.update(tokenizer_kwargs)
kwargs = DEFAULT_TOKENIZER_KWARGS.copy()
kwargs.update(tokenizer_kwargs)


        try:
            generation_config = GenerationConfig.from_pretrained(path)
        except Exception:


medium

except Exception: is overly broad; it catches every exception type and can mask unexpected errors. Catch something more specific, such as OSError or the specific exceptions the transformers library raises when model loading fails, for better robustness and maintainability.

"'pip install \"fschat[model_worker,webui]\"' "
'to implement fastchat.')
for idx, text in enumerate(messages):
conv = get_conversation_template('vicuna')


medium

The conversation template is hardcoded to 'vicuna' here. That limits flexibility, since different models may need different conversation templates. Make the template name configurable, e.g., via generation_kwargs or a new __init__ parameter, to support a wider range of models.
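
A minimal sketch of that suggestion, assuming a new (hypothetical) fastchat_template constructor argument, with 'vicuna' kept as the default so existing configs continue to work:

    # Sketch only; `fastchat_template` is a hypothetical new parameter.
    def __init__(self, *args, fastchat_template: str = 'vicuna', **kwargs):
        super().__init__(*args, **kwargs)
        self.fastchat_template = fastchat_template

    # ...later, when building prompts:
    #     conv = get_conversation_template(self.fastchat_template)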

            conv.append_message(conv.roles[1], None)
            messages[idx] = conv.get_prompt()
        if self.mode == 'mid':
            assert len(messages) == 1


medium

assert statements are meant for checking internal invariants during development, not for validating runtime input. If the messages list does not have length 1 in 'mid' mode, the assert crashes the program without a clear error message. Use an if check that raises a ValueError with an explicit explanation instead.

Suggested change
assert len(messages) == 1
if len(messages) != 1:
    raise ValueError(f"The 'mid' mode only supports a batch size of 1, but got {len(messages)}")

Comment on lines +294 to +305
        if self.end_str:
            decodeds = [text.split(self.end_str)[0] for text in decodeds]
        if origin_stopping_criteria:
            for token in origin_stopping_criteria:
                decodeds = [text.split(token)[0] for text in decodeds]


medium

The current post-processing of end_str and origin_stopping_criteria rebuilds the list (decodeds = [...]) inside a loop, which is inefficient when decodeds or the stop-token list is large. Combine all stop tokens and process each decoded string in a single pass for better performance.

Suggested change
if self.end_str:
    decodeds = [text.split(self.end_str)[0] for text in decodeds]
if origin_stopping_criteria:
    for token in origin_stopping_criteria:
        decodeds = [text.split(token)[0] for text in decodeds]
all_stop_tokens = []
if self.end_str:
    all_stop_tokens.append(self.end_str)
all_stop_tokens.extend(origin_stopping_criteria)
if all_stop_tokens:
    for i in range(len(decodeds)):
        text = decodeds[i]
        for token in all_stop_tokens:
            text = text.split(token)[0]
        decodeds[i] = text

Comment on lines 65 to 96
            if self.abbr == 'mindformer-model':
                command = (
                    f"msrun "
                    f"--worker_num={self.num_gpus} "
                    f"--local_worker_num={self.num_gpus} "
                    f"--master_port={port} "
                    f"--log_dir='output/msrun_log' "
                    f"--join=True "
                    f"--cluster_time_out=7200 "
                    f'{script_path} {cfg_path}'
                )
            else :
                command = (f'torchrun --master_port={port} '
                           f'--nproc_per_node {self.num_procs} '
                           f'{script_path} {cfg_path}')
        elif self.nnodes > 1:
            port = 12345
            command = (f'torchrun --master_port={port} '
            if self.abbr == "mindformer-model" :
                command = (
                    f"msrun "
                    f"--worker_num={self.num_procs} "
                    f"--local_worker_num={self.local_worker_num} "
                    f"--master_port={port} "
                    f"--master_addr={self.master_addr} "
                    f"--node_rank={self.node_rank} "
                    f"--log_dir='output/msrun_log' "
                    f"--join=True "
                    f"--cluster_time_out=7200 "
                    f'{script_path} {cfg_path}'
                )
            else :


medium

The logic that builds the msrun command is largely duplicated between the single-node multi-device case (self.nnodes == 1) and the multi-node case (self.nnodes > 1), which makes the code hard to maintain and modify. Extract the msrun command construction into a separate helper function to reduce duplication and improve readability.

Comment on lines +144 to +147
    if attr == "local" and abbr == "mindformer-model" :
        return MODELS.build(model_cfg, batch_size = batch_size)
    else :
        return MODELS.build(model_cfg)


medium

Special build logic is added here for mindformer-model so that batch_size can be passed at model construction. Even if this is necessary, it introduces a special case and reduces maintainability. Add a comment explaining why MindFormerModel needs batch_size in __init__ (e.g., for the placeholder tensor used during model loading) so other developers understand the reason for this special handling.

@muqing-li (Author)

/gemini review

@gemini-code-assist bot left a comment

Code Review

Thanks for the contribution. This PR adds MindFormers model support to AISBench, which is a great feature. The code is well structured overall and successfully integrates msrun to launch MindFormers models.

I found a few areas for improvement, mainly:

  1. Multi-node execution bug: in openicl_infer.py, the worker_num argument of the multi-node msrun command is set incorrectly, which can make multi-node jobs fail.
  2. Robustness: mindformers_model.py has issues that can cause runtime errors or unexpected behavior, such as mutating a shared default dict, unsafe dictionary access, and asserts that should be ValueErrors.
  3. Maintainability: build.py adds special-case handling for MindFormers models, which hurts long-term maintainability.
  4. Performance: the string post-processing in mindformers_model.py could be more efficient.

See my comments for concrete suggestions. With these fixed, the code will be more robust and easier to maintain.

            if self.abbr == "mindformer-model" :
                command = (
                    f"msrun "
                    f"--worker_num={self.num_procs} "


critical

In the multi-node case (self.nnodes > 1), msrun's --worker_num should be the total number of processes, i.e. nproc_per_node * nnodes. The code incorrectly sets it to self.num_procs (i.e. nproc_per_node), so msrun cannot start all processes across all nodes correctly.

Suggested change
f"--worker_num={self.num_procs} "
f"--worker_num={self.num_procs * self.nnodes} "

Comment on lines 110 to 111
        kwargs = DEFAULT_TOKENIZER_KWARGS
        kwargs.update(tokenizer_kwargs)


high

Here, kwargs is assigned DEFAULT_TOKENIZER_KWARGS directly, which is a dict, and the following kwargs.update(tokenizer_kwargs) therefore mutates DEFAULT_TOKENIZER_KWARGS itself. Across multiple calls to _load_tokenizer, DEFAULT_TOKENIZER_KWARGS accumulates earlier modifications, which can cause unexpected side effects. Use the .copy() method to create a copy and avoid this.

Suggested change
kwargs = DEFAULT_TOKENIZER_KWARGS
kwargs.update(tokenizer_kwargs)
kwargs = DEFAULT_TOKENIZER_KWARGS.copy()
kwargs.update(tokenizer_kwargs)

Comment on lines +277 to +278
        if isinstance(outputs, dict):
            outputs = outputs.get('sequences', outputs)


high

The code outputs = outputs.get('sequences', outputs) is risky: if outputs is a dict without a 'sequences' key, get returns the default outputs (the dict itself), so the next line [seq.tolist() for seq in outputs] iterates over the dict and raises an AttributeError. Handle the missing-'sequences' case explicitly so the failure is reported correctly.

Suggested change
if isinstance(outputs, dict):
    outputs = outputs.get('sequences', outputs)
if isinstance(outputs, dict):
    outputs = outputs.get('sequences')
    if outputs is None:
        raise ValueError("Model output dictionary is missing 'sequences' key.")

                command = (
                    f"msrun "
                    f"--worker_num={self.num_procs} "
                    f"--local_worker_num={self.local_worker_num} "


high

In the multi-node case, --local_worker_num uses self.local_worker_num read from the config. However, this new option defaults to 0 in __init__ (run_cfg.get('local_worker_num', 0)). If the user does not set local_worker_num explicitly in the config file, msrun receives --local_worker_num=0, which will very likely fail to launch.

For robustness, use self.num_procs (the per-node process count) directly as the local_worker_num value; this is also more consistent with the single-node multi-device logic.

Suggested change
f"--local_worker_num={self.local_worker_num} "
f"--local_worker_num={self.num_procs} "

from typing import Dict, List, Optional, Union

import numpy as np
import torch


medium

The torch module is imported in this file but never used anywhere in it. Remove the unused import to keep the code clean and avoid an unnecessary dependency.

            conv.append_message(conv.roles[1], None)
            messages[idx] = conv.get_prompt()
        if self.mode == 'mid':
            assert len(messages) == 1


medium

mode == 'mid' 的情况下,使用了 assert 来检查 messages 列表的长度。assert 语句在生产环境中可能会被禁用(例如,使用 -O 标志运行 Python),这会导致校验失效。建议使用 raise ValueError 来进行参数校验,并提供更明确的错误信息,这样更健壮。

Suggested change
assert len(messages) == 1
if len(messages) != 1:
    raise ValueError(f"In 'mid' mode, expected 1 message, but got {len(messages)}")

Comment on lines +301 to +305
        if origin_stopping_criteria:
            for token in origin_stopping_criteria:
                decodeds = [text.split(token)[0] for text in decodeds]


medium

This code loops repeatedly to strip the possible stop tokens in origin_stopping_criteria, building a new list on every pass, which is inefficient with many stop tokens or long decoded texts. It can be optimized to a single pass that finds each text's first stop-token occurrence and truncates there.

Suggested change
if origin_stopping_criteria:
    for token in origin_stopping_criteria:
        decodeds = [text.split(token)[0] for text in decodeds]
if origin_stopping_criteria:
    new_decodeds = []
    for text in decodeds:
        first_stop_pos = len(text)
        for token in origin_stopping_criteria:
            pos = text.find(token)
            if pos != -1:
                first_stop_pos = min(first_stop_pos, pos)
        new_decodeds.append(text[:first_stop_pos])
    decodeds = new_decodeds

Comment on lines +144 to +147
    if attr == "local" and abbr == "mindformer-model" :
        return MODELS.build(model_cfg, batch_size = batch_size)
    else :
        return MODELS.build(model_cfg)


medium

Special logic is added here to pass the batch_size parameter for mindformer-model. It meets the current need, but the hardcoding makes the code hard to maintain: if other models later also need batch_size or other parameters at initialization, this if-else will keep growing. Consider a more generic solution, e.g., inspecting the model class's __init__ signature via reflection and dynamically passing only the parameters it supports.
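
A sketch of the reflection idea (it assumes an mmengine-style registry where MODELS.get returns the registered class; hypothetical, not the current implementation):

    # Sketch: pass batch_size only when the model class actually accepts it.
    import inspect

    def build_model_from_cfg(model_cfg, batch_size):
        model_cls = MODELS.get(model_cfg['type'])  # assumes a registry class-lookup API
        if 'batch_size' in inspect.signature(model_cls.__init__).parameters:
            return MODELS.build(model_cfg, batch_size=batch_size)
        return MODELS.build(model_cfg)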

@muqing-li (Author)

/gemini review

@gemini-code-assist bot left a comment

Code Review

This PR adds MindFormers model support so that ais_bench can load MindFormers models directly for evaluation. The main changes are the new MindFormerModel wrapper, the corresponding config file, and the task-launch changes to adapt to msrun. The implementation is complete and the code quality is high. I found a few areas for improvement, mostly around readability and maintainability, which I hope help raise the quality further.

        if isinstance(outputs, dict):
            outputs = outputs.get('sequences', outputs)
        if outputs is None:
            raise ValueError("Model output dictionary is missing 'seqence' key.")


high

There is a spelling error in the error message: seqence should be sequence. The correct key name matters when debugging a problem.

Suggested change
raise ValueError("Model output dictionary is missing 'seqence' key.")
raise ValueError("Model output dictionary is missing 'sequence' key.")

        seed = None,
        repetition_penalty = 1.03,
    ),
    run_cfg = dict(num_gpus=1, num_procs=1), # multi-device/multi-node settings; tasks are launched with torchrun


medium

The comment here says the task is launched with torchrun, but according to the implementation in openicl_infer.py, MindFormer models use msrun. The comment is misleading; change it to msrun for consistency.

Suggested change
run_cfg = dict(num_gpus=1, num_procs=1), # multi-device/multi-node settings; tasks are launched with torchrun
run_cfg = dict(num_gpus=1, num_procs=1), # multi-device/multi-node settings; tasks are launched with msrun

Comment on lines +66 to +80
        if self.abbr == 'mindformer-model':
            command = (
                f"msrun "
                f"--worker_num={self.num_gpus} "
                f"--local_worker_num={self.num_gpus} "
                f"--master_port={port} "
                f"--log_dir='output/msrun_log' "
                f"--join=True "
                f"--cluster_time_out=7200 "
                f'{script_path} {cfg_path}'
            )
        else :
            command = (f'torchrun --master_port={port} '
                       f'--nproc_per_node {self.num_procs} '
                       f'{script_path} {cfg_path}')


medium

The msrun command construction here is nearly identical to the logic in the elif self.nnodes > 1 branch (lines 83-95). Such duplication increases maintenance cost. Refactor to reduce it, e.g., extract the common parts of the msrun command or create a helper function to build it.

Comment on lines +72 to +74
f"--log_dir='output/msrun_log' "
f"--join=True "
f"--cluster_time_out=7200 "


medium

log_dir ('output/msrun_log') and cluster_time_out (7200) are hardcoded here. For flexibility, make them configurable, e.g., read them from run_cfg, so users can change them without modifying the code. The same suggestion applies to lines 91-93.

@muqing-li (Author)

/gemini review

@gemini-code-assist bot left a comment

Code Review

Thanks for the contribution. This PR adds MindFormers model support to ais_bench, which is a great feature. The implementation is comprehensive, covering model loading, inference, and multi-device execution.

My suggestions focus on maintainability, robustness, and style:

  • Refactoring: in openicl_infer.py, the logic that generates the msrun and torchrun commands is quite repetitive; refactor it for readability and maintainability.
  • Correctness: in mindformers_model.py, the post-processing of generated text may misbehave with multiple stop tokens; I have provided a more robust implementation.
  • Code quality: a few more improvements, such as removing unused imports, promoting in-method constants to class-level constants, and using ValueError instead of assert for argument validation.

Overall, this is a high-quality submission. Looking forward to these changes making the tool even better once merged.

Comment on lines +301 to +305
        if self.end_str:
            decodeds = [text.split(self.end_str)[0] for text in decodeds]
        if origin_stopping_criteria:
            for token in origin_stopping_criteria:
                decodeds = [text.split(token)[0] for text in decodeds]


high

The current handling of end_str and origin_stopping_criteria (repeated loops with split) is inefficient, and can produce incorrect results when stop tokens contain each other or appear in a different order. Refactor this logic to find the first stop-token occurrence in a single pass and truncate there; it is both faster and more robust.

Suggested change
if self.end_str:
    decodeds = [text.split(self.end_str)[0] for text in decodeds]
if origin_stopping_criteria:
    for token in origin_stopping_criteria:
        decodeds = [text.split(token)[0] for text in decodeds]
all_stop_tokens = ([self.end_str] if self.end_str else []) + origin_stopping_criteria
if all_stop_tokens:
    new_decodeds = []
    for text in decodeds:
        min_index = len(text)
        for token in all_stop_tokens:
            if not token:
                continue
            idx = text.find(token)
            if idx != -1:
                min_index = min(min_index, idx)
        new_decodeds.append(text[:min_index])
    decodeds = new_decodeds

@@ -0,0 +1,306 @@
import os, sys


medium

For readability, follow the PEP 8 style guide and import one module per line.

Suggested change
import os, sys
import os
import sys

from typing import Dict, List, Optional, Union

import numpy as np
import torch


medium

The torch module is imported but unused in this file. Remove the unused import to keep the code clean.

                        tokenizer_kwargs: dict):
        from transformers import AutoTokenizer, GenerationConfig

        DEFAULT_TOKENIZER_KWARGS = dict(padding_side='left', truncation_side='left', trust_remote_code=True)


medium

DEFAULT_TOKENIZER_KWARGS is a constant and is better defined at class level (inside MindFormerModel, outside the __init__ method) rather than inside the _load_tokenizer method. That improves readability and maintainability and makes its scope explicit.

For example:

class MindFormerModel(BaseModel):
    DEFAULT_TOKENIZER_KWARGS = dict(...)

    def __init__(self, ...):
        ...

    def _load_tokenizer(self, ...):
        kwargs = self.DEFAULT_TOKENIZER_KWARGS.copy()
        ...

            conv.append_message(conv.roles[1], None)
            messages[idx] = conv.get_prompt()
        if self.mode == 'mid':
            assert len(messages) == 1


medium

Using assert for input validation is not best practice, since assertions can be disabled in production (e.g., running Python with the -O flag). Use a ValueError for the runtime check and provide a clearer error message.

Suggested change
assert len(messages) == 1
if len(messages) != 1:
    raise ValueError(f"Expected a single message in 'mid' mode, but got {len(messages)}.")

Comment on lines +144 to +147
    if attr == "local" and abbr == "mindformer-model" :
        return MODELS.build(model_cfg, batch_size = batch_size)
    else :
        return MODELS.build(model_cfg)


medium

Special handling is added here for mindformer-model so that batch_size is passed to the model constructor. It works, but putting model-specific logic in the generic build_model_from_cfg function reduces maintainability, and the logic will become more complex if more models need special treatment in the future. Consider a more generic design, e.g., a config option controlling which parameters should not be removed from model_cfg.
